{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🏔️ Step-back prompting with workflows for RAG with Argilla"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial will show how to use step-back prompting with LlamaIndex workflows for RAG integrated with Argilla.\n",
    "\n",
    "This prompting approach is based on \"[Take a Step Back: Evoking Reasoning via Abstraction in Large Language Models](https://arxiv.org/abs/2310.06117)\". This paper suggests that the response can be improved by asking the model to take a step back and reason about the context in a more abstract way. This way, the original query is abstracted and used to retrieved the relevant information. Then, this context along with the original context and query are used to generate the final response. \n",
    "\n",
    "[Argilla](https://github.com/argilla-io/argilla) is a collaboration tool for AI engineers and domain experts to build high-quality datasets. By doing this, you can analyze and enhance the quality of your data, leading to improved model performance by incorporating human feedback into the loop. The integration will automatically log the query, response, retrieved contexts with their scores, and the full trace (including spans and events), along with relevant metadata in Argilla. By default, you'll have the ability to rate responses, provide feedback, and evaluate the retrieved contexts, ensuring accuracy and preventing any discrepancies.\n",
    "\n",
    "It includes the following steps:\n",
    "\n",
    "- Setting up the Argilla handler for LlamaIndex.\n",
    "- Designing the step-back workflow.\n",
    "- Run the step-back workflow with LlamaIndex and automatically log the responses to Argilla."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Getting started\n",
    "\n",
    "### Deploy the Argilla server¶\n",
    "\n",
    "If you already have deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla following [this guide](https://docs.argilla.io/latest/getting_started/quickstart/).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up the environment¶\n",
    "\n",
    "To complete this tutorial, you need to install this integration.\n",
    "\n",
    "> Check the GitHub repository [here](https://github.com/argilla-io/argilla-llama-index)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install \"argilla-llama-index>=2.1.0\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's make the required imports:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import (\n",
    "    Settings,\n",
    "    SimpleDirectoryReader,\n",
    "    VectorStoreIndex,\n",
    ")\n",
    "from llama_index.core.instrumentation import get_dispatcher\n",
    "from llama_index.core.node_parser import SentenceSplitter\n",
    "from llama_index.core.response_synthesizers import ResponseMode\n",
    "from llama_index.core.schema import NodeWithScore\n",
    "from llama_index.core.workflow import (\n",
    "    Context,\n",
    "    StartEvent,\n",
    "    StopEvent,\n",
    "    Workflow,\n",
    "    step,\n",
    ")\n",
    "\n",
    "from llama_index.core import get_response_synthesizer\n",
    "from llama_index.core.workflow import Event\n",
    "from llama_index.utils.workflow import draw_all_possible_flows\n",
    "from llama_index.llms.openai import OpenAI\n",
    "\n",
    "from argilla_llama_index import ArgillaHandler"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need to set the OpenAI API key. The OpenAI API key is required to run queries using GPT models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set the Argilla's LlamaIndex handler\n",
    "\n",
    "To easily log your data into Argilla within your LlamaIndex workflow, you only need to initialize the Argilla handler and attach it to the Llama Index dispatcher for spans and events. This ensures that the predictions obtained using LlamaIndex are automatically logged to the Argilla instance, along with the useful metadata.\n",
    "\n",
    "- `dataset_name`: The name of the dataset. If the dataset does not exist, it will be created with the specified name. Otherwise, it will be updated.\n",
    "- `api_url`: The URL to connect to the Argilla instance.\n",
    "- `api_key`: The API key to authenticate with the Argilla instance.\n",
    "- `number_of_retrievals`: The number of retrieved documents to be logged. Defaults to 0.\n",
    "- `workspace_name`: The name of the workspace to log the data. By default, the first available workspace.\n",
    "\n",
    "> For more information about the credentials, check the documentation for [users](https://docs.argilla.io/latest/how_to_guides/user/) and [workspaces](https://docs.argilla.io/latest/how_to_guides/workspace/).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "argilla_handler = ArgillaHandler(\n",
    "    dataset_name=\"workflow_llama_index\",\n",
    "    api_url=\"http://localhost:6900\",\n",
    "    api_key=\"argilla.apikey\",\n",
    "    number_of_retrievals=2,\n",
    ")\n",
    "root_dispatcher = get_dispatcher()\n",
    "root_dispatcher.add_span_handler(argilla_handler)\n",
    "root_dispatcher.add_event_handler(argilla_handler)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the step-back workflow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we need to define the two events that will be used in the step-back workflow. The `StepBackEvent` that will receive the step-back query, and the `RetriverEvent` that will receive the relevant nodes for the original and step-back queries after the retrieval."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class StepBackEvent(Event):\n",
    "    \"\"\"Get the step-back query\"\"\"\n",
    "\n",
    "    step_back_query: str\n",
    "\n",
    "\n",
    "class RetrieverEvent(Event):\n",
    "    \"\"\"Result of running the retrievals\"\"\"\n",
    "\n",
    "    nodes_original: list[NodeWithScore]\n",
    "    nodes_step_back: list[NodeWithScore]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we will define the prompts according to the original paper to get the step-back query and then get the final response."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "STEP_BACK_TEMPLATE = \"\"\"\n",
    "You are an expert at world knowledge. Your task is to step back and\n",
    "paraphrase a question to a more generic step-back question, which is\n",
    "easier to answer. Here are a few examples:\n",
    "\n",
    "Original Question: Which position did Knox Cunningham hold from May 1955 to Apr 1956?\n",
    "Stepback Question: Which positions have Knox Cunningham held in his career?\n",
    "\n",
    "Original Question: Who was the spouse of Anna Karina from 1968 to 1974?\n",
    "Stepback Question: Who were the spouses of Anna Karina?\n",
    "\n",
    "Original Question: what is the biggest hotel in las vegas nv as of November 28, 1993\n",
    "Stepback Question: what is the size of the hotels in las vegas nv as of November 28, 1993?\n",
    "\n",
    "Original Question: {original_query}\n",
    "Stepback Question:\n",
    "\"\"\"\n",
    "\n",
    "GENERATE_ANSWER_TEMPLATE = \"\"\"\n",
    "You are an expert of world knowledge. I am going to ask you a question.\n",
    "Your response should be comprehensive and not contradicted with the\n",
    "following context if they are relevant. Otherwise, ignore them if they are\n",
    "not relevant.\n",
    "\n",
    "{context_original}\n",
    "{context_step_back}\n",
    "\n",
    "Original Question: {query}\n",
    "Answer:\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we will define the step-back workflow. In this case, the workflow will be linear. First, we will prompt the LLM to make an abstraction of the original query (step-back prompting). Then, we will retrieve the relevant nodes for the original and step-back queries. Finally, we will prompt the LLM to generate the final response."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RAGWorkflow(Workflow):\n",
    "    @step\n",
    "    async def step_back(\n",
    "        self, ctx: Context, ev: StartEvent\n",
    "    ) -> StepBackEvent | None:\n",
    "        \"\"\"Generate the step-back query.\"\"\"\n",
    "        query = ev.get(\"query\")\n",
    "        index = ev.get(\"index\")\n",
    "\n",
    "        if not query:\n",
    "            return None\n",
    "\n",
    "        if not index:\n",
    "            return None\n",
    "\n",
    "        llm = Settings.llm\n",
    "        step_back_query = llm.complete(\n",
    "            prompt=STEP_BACK_TEMPLATE.format(original_query=query),\n",
    "            formatted=True,\n",
    "        )\n",
    "\n",
    "        await ctx.store.set(\"query\", query)\n",
    "        await ctx.store.set(\"index\", index)\n",
    "\n",
    "        return StepBackEvent(step_back_query=str(step_back_query))\n",
    "\n",
    "    @step\n",
    "    async def retrieve(\n",
    "        self, ctx: Context, ev: StepBackEvent\n",
    "    ) -> RetrieverEvent | None:\n",
    "        \"Retrieve the relevant nodes for the original and step-back queries.\"\n",
    "        query = await ctx.store.get(\"query\", default=None)\n",
    "        index = await ctx.store.get(\"index\", default=None)\n",
    "\n",
    "        await ctx.store.set(\"step_back_query\", ev.step_back_query)\n",
    "\n",
    "        retriever = index.as_retriever(similarity_top_k=2)\n",
    "        nodes_step_back = await retriever.aretrieve(ev.step_back_query)\n",
    "        nodes_original = await retriever.aretrieve(query)\n",
    "\n",
    "        return RetrieverEvent(\n",
    "            nodes_original=nodes_original, nodes_step_back=nodes_step_back\n",
    "        )\n",
    "\n",
    "    @step\n",
    "    async def synthesize(self, ctx: Context, ev: RetrieverEvent) -> StopEvent:\n",
    "        \"\"\"Return a response using the contextualized prompt and retrieved nodes.\"\"\"\n",
    "        nodes_original = ev.nodes_original\n",
    "        nodes_step_back = ev.nodes_step_back\n",
    "\n",
    "        context_original = max(\n",
    "            nodes_original, key=lambda node: node.get_score()\n",
    "        ).get_text()\n",
    "        context_step_back = max(\n",
    "            nodes_step_back, key=lambda node: node.get_score()\n",
    "        ).get_text()\n",
    "\n",
    "        query = await ctx.store.get(\"query\", default=None)\n",
    "        formatted_query = GENERATE_ANSWER_TEMPLATE.format(\n",
    "            context_original=context_original,\n",
    "            context_step_back=context_step_back,\n",
    "            query=query,\n",
    "        )\n",
    "\n",
    "        response_synthesizer = get_response_synthesizer(\n",
    "            response_mode=ResponseMode.COMPACT\n",
    "        )\n",
    "\n",
    "        response = response_synthesizer.synthesize(\n",
    "            formatted_query, nodes=ev.nodes_original\n",
    "        )\n",
    "        return StopEvent(result=response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "draw_all_possible_flows(RAGWorkflow, filename=\"step_back_workflow.html\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![screenshot-workflow.png]()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the step-back workflow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use an example `.txt` file obtained from the [LlamaIndex documentation](https://docs.llamaindex.ai/en/stable/getting_started/starter_example.html). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieve the data if needed\n",
    "!mkdir -p ../../data\n",
    "!curl https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/paul_graham/paul_graham_essay.txt -o ../../data/paul_graham_essay.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's create a LlamaIndex index out of this document. As the highest-rated context for the original and step-back query will be included in the final prompt, we will lower the chuck size and use a `SentenceSplitter`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LLM settings\n",
    "Settings.llm = OpenAI(model=\"gpt-3.5-turbo\", temperature=0.8)\n",
    "\n",
    "# Load the data and create the index\n",
    "transformations = [\n",
    "    SentenceSplitter(chunk_size=256, chunk_overlap=75),\n",
    "]\n",
    "\n",
    "documents = SimpleDirectoryReader(\"../../data\").load_data()\n",
    "index = VectorStoreIndex.from_documents(\n",
    "    documents=documents,\n",
    "    transformations=transformations,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's run the step-back workflow and make a query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = RAGWorkflow()\n",
    "\n",
    "result = await w.run(query=\"What's Paul's work\", index=index)\n",
    "result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The generated response will be automatically logged in our Argilla instance. Check it out! From Argilla, you can quickly look at your predictions and annotate them so you can combine both synthetic data and human feedback.\n",
    "\n",
 You can check">
    "> You can check [this guide](https://docs.argilla.io/latest/how_to_guides/annotate/) to learn how to annotate your data.\n",
    "\n",
    "![UI-screeshot-workflow.png]()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next steps\n",
    "\n",
    "Once you've annotated your data, you can retrieve it from Argilla. By integrating human feedback into the process, we’ve guaranteed data quality, making it ready for fine-tuning your model. Moreover, to maintain model performance and prevent data drift, you can set aside a portion of the data for ongoing evaluation over time."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "argilla-llama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
